package downloader

import (
	"context"
	"gitee.com/Luna-CY/hui-hui/internal/util/goroutine"
	"github.com/cheggaaa/pb/v3"
	"io"
	"net/http"
	"os"
)

// writer passes every Write through to w and then publishes the number of
// bytes written on ch, so another goroutine can observe copy progress.
type writer struct {
	// w receives the actual data.
	w io.Writer

	// ch carries the byte count of each Write call; Close closes it.
	ch chan int64
}

// Write forwards p to the wrapped writer, then sends the number of bytes it
// actually accepted (possibly 0) on ch. The send blocks while ch's buffer is
// full, so a slow consumer throttles the writer. The wrapped writer's count
// and error are returned unchanged.
func (cls *writer) Write(p []byte) (int, error) {
	written, writeErr := cls.w.Write(p)
	cls.ch <- int64(written)
	return written, writeErr
}

// Close closes the progress channel. A consumer ranging over ch stops once it
// has drained any buffered counts. Close always returns nil; like any channel
// close, calling it a second time panics.
func (cls *writer) Close() error {
	close(cls.ch)
	return nil
}

// statusError reports that the server answered the download request with a
// non-2xx HTTP status, so the response body is not the requested content.
type statusError struct {
	status string
}

func (e *statusError) Error() string {
	return "unexpected http status: " + e.status
}

// DownloadToTempFile downloads url into a newly created temporary file.
//
// On success the returned *os.File is open and positioned at offset 0. The
// caller is responsible for closing it and deleting it (os.Remove(f.Name()))
// after use. On any error no file is returned; a temporary file that was
// already created has been closed and removed.
//
// A response with a non-2xx status code is returned as an error (*statusError)
// instead of being saved as the download.
//
// callback may be nil. Otherwise it is invoked with (total, current) after every
// chunk written to the file. total is the response Content-Length (-1 if the
// server did not provide one) and current is the cumulative number of bytes
// written. It runs on a separate goroutine, but all invocations have finished
// before DownloadToTempFile returns.
//
// The first return value is the number of bytes written to the file.
func DownloadToTempFile(ctx context.Context, url string, callback func(total int64, current int64)) (int64, *os.File, error) {
	var request, err = http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if nil != err {
		return 0, nil, err
	}

	response, err := http.DefaultClient.Do(request)
	if nil != err {
		return 0, nil, err
	}

	defer func() {
		_ = response.Body.Close()
	}()

	// Without this check an error page (404, 500, ...) would be written to the
	// temp file and handed back as if it were the requested content.
	if response.StatusCode < http.StatusOK || response.StatusCode >= http.StatusMultipleChoices {
		return 0, nil, &statusError{status: response.Status}
	}

	temp, err := os.CreateTemp("", "")
	if nil != err {
		return 0, nil, err
	}

	written, err := copyWithProgress(temp, response.Body, response.ContentLength, callback)
	if nil == err {
		// Rewind so the caller can read the downloaded content from the start.
		_, err = temp.Seek(0, io.SeekStart)
	}

	if nil != err {
		// Best-effort cleanup: the download already failed, and that error is
		// the one worth reporting.
		_ = temp.Close()
		_ = os.Remove(temp.Name())

		return 0, nil, err
	}

	return written, temp, nil
}

// copyWithProgress copies src into dst while drawing a pb progress bar sized by
// total. When callback is non-nil it also reports cumulative progress to it
// from a separate goroutine, and waits for that goroutine to finish before
// returning. It returns the number of bytes copied and the io.Copy error.
func copyWithProgress(dst io.Writer, src io.Reader, total int64, callback func(total int64, current int64)) (int64, error) {
	var progress = pb.Full.Start64(total)
	defer progress.Finish()

	var out io.Writer = progress.NewProxyWriter(dst)
	if nil == callback {
		// No consumer for progress events: writing through the channel-based
		// writer would eventually block forever once its buffer filled.
		return io.Copy(out, src)
	}

	var w = &writer{w: out, ch: make(chan int64, 100)}
	var done = make(chan struct{})

	// NOTE(review): assumes goroutine.Go runs fn asynchronously; the deferred
	// close(done) also fires if callback panics and the panic unwinds fn.
	goroutine.Go(func() {
		defer close(done)

		var current int64
		for n := range w.ch {
			current += n
			callback(total, current)
		}
	})

	written, err := io.Copy(w, src)

	// Closing the channel ends the range loop above. Waiting on done guarantees
	// every callback, including the final one, has run before we return.
	_ = w.Close()
	<-done

	return written, err
}
